import logging
import yaml
from functools import partial

from bluepyparallel import init_parallel_factory
from emodel_generalisation.model.modifiers import synth_soma, synth_axon
from emodel_generalisation.model.access_point import AccessPoint
from emodel_generalisation.mcmc import run_several_chains
import extra_features

# Script: configure an AccessPoint for the "simplest" emodel, attach soma/axon
# morphology modifiers built from exemplar_data.yaml, and run MCMC chains via
# emodel_generalisation.mcmc.run_several_chains.

# Root logger (getLogger() with no name) so library log output is also shown.
logger = logging.getLogger()
logging.basicConfig(level=logging.INFO)

if __name__ == "__main__":
    # Use init_parallel_factory("multiprocessing") instead to run without dask.
    parallel_factory = init_parallel_factory("dask_dataframe")
    emodel = "simplest"

    access_point = AccessPoint(
        emodel_dir=".",
        recipes_path="config/recipes.json",
        # Optionally pass final_path="final.json" here.
    )

    # Close the YAML file as soon as it is parsed instead of leaking the handle.
    with open("exemplar_data.yaml", encoding="utf-8") as exemplar_file:
        exemplar_data = yaml.safe_load(exemplar_file)
    access_point.morph_path = exemplar_data["paths"]["all"]

    # Modifier callables stored in the access point settings; the soma and AIS
    # parameters come from the exemplar data file.
    access_point.settings["morph_modifiers"] = [
        partial(synth_soma, params=exemplar_data["soma"], scale=1.0),
        partial(synth_axon, params=exemplar_data["ais"]["popt"], scale=1.0),
    ]

    try:
        run_several_chains(
            proposal_params={"type": "normal", "std": 0.1},  # increase std to propose larger jumps
            temperature=0.5,  # increase to explore higher cost values
            n_steps=20000,  # maximum number of steps (deliberately set high)
            n_chains=1000,  # set to the number of CPUs available on one node
            emodel=emodel,
            access_point=access_point,
            run_df_path="run_df.csv",
            results_df_path="chains",
            parallel_lib=parallel_factory,
        )
    finally:
        # Release the parallel backend (e.g. the dask client) even if the run fails.
        parallel_factory.shutdown()
